package websockethttp

import (
	"container/list"
	"encoding/json"
	"github.com/gorilla/websocket"
	"log"
	"strconv"
	"time"
)

// Client is a websocket client that both sends requests to a peer and serves requests
// pushed by that peer over a single connection. Per the read loop, requests travel as
// text frames and responses as binary frames. A Client must be created with ClientNew:
// the zero value has nil maps and lists and panics on first use.
type Client struct {
	Conn                  *websocket.Conn                         // current connection; replaced on reconnect, nil after CloseConnection
	Name                  string                                  // connection name (used in logs)
	Url                   string                                  // URL dialed on connect and on every reconnect
	requestTimeout        int64                                   // timeout (ms) for requests this client sends
	pongTaskTime          int64                                   // heartbeat interval (ms)
	isShowPongLogs        bool                                    // whether heartbeat replies are logged
	isReconnectingStatus  bool                                    // set while a reconnect dial is in progress (plain bool, not synchronized)
	pongTaskTicker        *time.Ticker                            // drives the heartbeat task
	requestTimeoutTicker  *time.Ticker                            // drives the pending-request timeout sweep
	serverRequestFilters  *list.List                              // filters for requests received from the peer (handlerRequest)
	serverResponseFilters *list.List                              // filters for responses sent back to the peer (handlerRequest)
	clientRequestFilters  *list.List                              // filters for requests this client sends (sendRequest)
	clientResponseFilters *list.List                              // filters for responses to this client's requests (handlerResponse)
	processFuncMap        map[string]func(context *ClientContext) // registered process handlers, keyed by process name
	requestCallback       map[string]*ClientRequestCallback       // callbacks awaiting a response, keyed by request Uid
	requestCommonHeader   map[string]string                       // headers merged into every outgoing request
	responseCommonHeader  map[string]string                       // headers copied into every outgoing response
	writeBeforeFunc       func(data []byte) []byte                // transform applied before writing (e.g. compression)
	readBeforeFunc        func(data []byte) []byte                // transform applied after reading (e.g. decompression)
	defaultProcessFunc    func(context *ClientContext)            // fallback when no process handler matches
}

// ClientContext carries the state of one request/response exchange. It is handed to
// filters, process handlers and request callbacks.
type ClientContext struct {
	Client   *Client         `json:"client"`   // client the exchange belongs to
	Request  *SocketRequest  `json:"request"`  // the request (incoming or outgoing)
	Response *SocketResponse `json:"response"` // the response being built or received
	Extra    interface{}     `json:"extra"`    // not used in this file; presumably a free slot for filters/handlers — confirm
	Error    error           `json:"error"`    // error recorded while sending the request in sendRequest
}

// ClientRequestCallback records a sent request that is waiting for its response: the
// request's context, the callback to run when the response (or a timeout) arrives, and
// when it was registered.
// NOTE: encoding/json cannot marshal func values, so despite its tags this struct is
// not JSON-serializable as-is.
type ClientRequestCallback struct {
	Context    *ClientContext           `json:"context"`
	Callback   func(ctx *ClientContext) `json:"callback"`
	CreateTime int64                    `json:"create_time"` // registration time, Unix milliseconds
}

// openConnection dials url with websocket.DefaultDialer and installs a close handler
// that logs peer-initiated closes. The handshake HTTP response is discarded.
//
// Installing a custom close handler replaces gorilla's default, which echoes the close
// frame back to the peer. The handler below sends that reply itself so the closing
// handshake still completes.
func (client *Client) openConnection(url string) (*websocket.Conn, error) {
	conn, _, err := websocket.DefaultDialer.Dial(url, nil)
	if err != nil {
		return nil, err
	}
	conn.SetCloseHandler(func(code int, text string) error {
		log.Printf("SetCloseHandler 服务端主动关闭连接：code(%v) msg(%v)", code, text)
		// Mirror the default handler: reply with a close frame carrying the same code.
		// WriteControl may be called concurrently with other methods. Its error is
		// ignored because the connection is closing anyway.
		reply := websocket.FormatCloseMessage(code, "")
		_ = conn.WriteControl(websocket.CloseMessage, reply, time.Now().Add(time.Second))
		return nil
	})
	return conn, nil
}

// retryConnection dials client.Url again. On success it swaps the new connection into
// client.Conn and closes the connection it replaced. It is used by the read loop after a
// read error and by the heartbeat callback after a failed heartbeat.
//
// It does nothing when a reconnect is already running, or when Conn is nil. Conn is nil
// only after CloseConnection, so a client closed by the caller is never silently
// reopened. If CloseConnection runs while the dial is in flight, the new connection is
// discarded.
//
// NOTE(review): isReconnectingStatus is a plain bool shared between goroutines, so this
// guard only narrows the window for concurrent reconnects. An atomic or mutex-protected
// flag would make it reliable.
func (client *Client) retryConnection() {
	if client.isReconnectingStatus {
		return
	}
	previous := client.Conn
	if previous == nil {
		return // closed by CloseConnection
	}

	client.isReconnectingStatus = true
	conn, err := client.openConnection(client.Url)
	client.isReconnectingStatus = false

	if err != nil || conn == nil {
		log.Printf("RetryConnection: 重连失败 error(%v)", err)
		return
	}
	if client.Conn == nil {
		// CloseConnection ran while we were dialing: do not resurrect the client.
		_ = conn.Close()
		return
	}

	log.Printf("RetryConnection: 重连成功 ConnName(%v)", client.Name)
	client.Conn = conn
	// Release the replaced connection. Closing it also unblocks a reader still waiting
	// on it, so that reader can move on to the new connection.
	_ = previous.Close()
}

// readMessage is the receive loop started by ClientNew. Text frames (requests) are
// handed to handlerRequest and binary frames (responses) to handlerResponse, each on
// its own goroutine. Other frame types are ignored.
//
// A connection that failed a read is never read again: gorilla/websocket panics after
// repeated reads on a failed connection. Instead the loop calls retryConnection,
// pausing between unsuccessful attempts, until Conn has been replaced by a new
// connection or cleared by CloseConnection. The loop returns once Conn is nil.
func (client *Client) readMessage() {
	// Pause between reconnect attempts so an unreachable server cannot cause a busy loop.
	const reconnectBackoff = time.Second

	for {
		conn := client.Conn
		if conn == nil {
			return // closed via CloseConnection
		}

		mt, body, err := conn.ReadMessage() // blocks until a frame arrives
		if err != nil {
			for client.Conn == conn {
				client.retryConnection()
				if client.Conn == conn {
					time.Sleep(reconnectBackoff)
				}
			}
			continue
		}

		switch mt {
		case websocket.TextMessage: // requests are sent as text frames
			go client.handlerRequest(body)
		case websocket.BinaryMessage: // responses are sent as binary frames
			go client.handlerResponse(body)
		}
	}
}

// handlerRequest serves a request pushed by the peer. The steps are:
//  1. Decode the payload (after readBeforeFunc).
//  2. Run the server request filters. A filter returning true drops the request without a reply.
//  3. Dispatch to the handler registered for request.Process, or to the default handler.
//  4. Run the server response filters. A filter returning true suppresses the reply.
//  5. Encode the response, pass it through writeBeforeFunc and send it.
//
// A request with an empty Process runs no handler, but a response is still sent.
func (client *Client) handlerRequest(data []byte) {
	// Decompression hook (identity by default).
	msg := client.readBeforeFunc(data)

	// Decode into the struct itself rather than &request: with a pointer-to-pointer a
	// JSON `null` payload would set request to nil and panic on request.Uid below.
	request := new(SocketRequest)
	if err := json.Unmarshal(msg, request); err != nil {
		log.Printf("handlerRequest: 服务端发送的数据错误 %v", string(msg))
		return
	}

	// The reply carries the request's Uid plus a copy of the common response headers.
	response := new(SocketResponse)
	response.Uid = request.Uid
	response.Header = make(map[string]string, len(client.responseCommonHeader))
	for k, v := range client.responseCommonHeader {
		response.Header[k] = v
	}

	context := &ClientContext{Client: client, Request: request, Response: response}

	for element := client.serverRequestFilters.Front(); element != nil; element = element.Next() {
		if element.Value.(func(c *ClientContext) bool)(context) {
			return
		}
	}

	if request.Process != "" {
		if processFunc, ok := client.processFuncMap[request.Process]; ok && processFunc != nil {
			processFunc(context)
		} else if client.defaultProcessFunc != nil {
			client.defaultProcessFunc(context) // no handler registered for this process
		}
	}

	for element := client.serverResponseFilters.Front(); element != nil; element = element.Next() {
		if element.Value.(func(c *ClientContext) bool)(context) {
			return
		}
	}

	jsonByte, err := json.Marshal(response)
	if err != nil {
		log.Printf("handlerRequest: response转为json异常 %v", err)
		return // nothing valid to send
	}

	// Compression hook (identity by default).
	payload := client.writeBeforeFunc(jsonByte)
	if err := writeResponseMessage(client.Conn, payload); err != nil {
		log.Printf("handlerRequest: write response error %v", err)
	}
}

// handlerResponse handles a response from the peer. It matches the response to the
// pending request with the same Uid and copies the response's header, code, message and
// body into that request's context. It then invokes the registered callback and runs the
// client response filters. Undecodable payloads are logged and dropped. So are responses
// with no pending callback, for example because the request already timed out.
func (client *Client) handlerResponse(data []byte) {
	// Decompression hook (identity by default).
	msg := client.readBeforeFunc(data)

	response := new(SocketResponse)
	if err := json.Unmarshal(msg, response); err != nil {
		log.Printf("handlerResponse: json转为response错误 %v", string(msg))
		return
	}

	rc, ok := client.requestCallback[response.Uid]
	if !ok || rc == nil {
		log.Printf("handlerResponse: 找不到对应的 callback(%v)", string(msg))
		return
	}
	// Unregister before running the callback rather than after it. Otherwise, while a slow
	// callback runs, the timeout sweep still sees the entry and may fire the callback twice.
	delete(client.requestCallback, response.Uid)

	// Copy the peer's reply into the context the caller is holding.
	rc.Context.Response.Header = response.Header
	rc.Context.Response.Code = response.Code
	rc.Context.Response.Message = response.Message
	rc.Context.Response.Body = response.Body

	rc.Callback(rc.Context)

	// NOTE(review): these filters run after the callback, so returning true only stops the
	// remaining filters and cannot suppress the callback. Confirm this order is intended.
	for element := client.clientResponseFilters.Front(); element != nil; element = element.Next() {
		if element.Value.(func(c *ClientContext) bool)(rc.Context) {
			return
		}
	}
}

// sendRequest sends request to the peer. It first merges the common request headers
// into request.Header; headers already on the request win on conflicts.
//
// When callback is non-nil it is registered under request.Uid before the client request
// filters run. The callback is then completed by handlerResponse, or by the timeout
// sweep if no response arrives. A filter returning true stops the request from being
// written; a registered callback is then completed only by the timeout sweep. The result
// of writing (or the encoding error) is stored in the context's Error field.
func (client *Client) sendRequest(request *SocketRequest, callback func(c *ClientContext)) {
	mergeHeader := make(map[string]string, len(client.requestCommonHeader)+len(request.Header))
	for k, v := range client.requestCommonHeader {
		mergeHeader[k] = v
	}
	for k, v := range request.Header {
		mergeHeader[k] = v
	}
	request.Header = mergeHeader

	// Placeholder response; handlerResponse fills it in when the reply arrives.
	response := new(SocketResponse)
	response.Uid = request.Uid

	context := &ClientContext{Client: client, Request: request, Response: response}

	if callback != nil {
		client.requestCallback[request.Uid] = &ClientRequestCallback{
			Context:    context,
			Callback:   callback,
			CreateTime: time.Now().UnixMilli(),
		}
	}

	for element := client.clientRequestFilters.Front(); element != nil; element = element.Next() {
		if element.Value.(func(c *ClientContext) bool)(context) {
			return
		}
	}

	jsonByte, err := json.Marshal(request)
	if err != nil {
		log.Printf("sendRequest: request转为json异常 %v", err)
		// Do not write a nil payload. Record the failure so a registered callback sees
		// it when the request times out.
		context.Error = err
		return
	}

	// Compression hook (identity by default).
	payload := client.writeBeforeFunc(jsonByte)
	context.Error = writeRequestMessage(client.Conn, payload)
}

// openPongTask runs the heartbeat loop and blocks for the lifetime of its ticker, so it
// is started on its own goroutine. The interval is pongTaskTime milliseconds, or
// PongProcessTime when pongTaskTime is unset or non-positive. On each tick, while Conn
// is non-nil, it sends a PongProcessName request whose body is the current Unix time in
// milliseconds.
//
// A reply carrying an error, or a code other than PongProcessSuccess, triggers
// retryConnection. Replies are logged when isShowPongLogs is set. Calling the function
// again stops the previous ticker first.
//
// NOTE: stopping a ticker does not close its channel, so a stopped loop stays parked
// rather than returning.
func (client *Client) openPongTask() {
	if client.pongTaskTime <= 0 {
		client.pongTaskTime = PongProcessTime
	}
	if old := client.pongTaskTicker; old != nil {
		old.Stop()
	}

	ticker := time.NewTicker(time.Duration(client.pongTaskTime) * time.Millisecond)
	client.pongTaskTicker = ticker

	onHeartbeat := func(c *ClientContext) {
		unhealthy := c.Error != nil || c.Response.Code != PongProcessSuccess
		if unhealthy {
			client.retryConnection()
		}
		if client.isShowPongLogs {
			log.Printf("Health: code(%d) msg(%v)", c.Response.Code, c.Response.Message)
		}
	}

	for range ticker.C {
		if client.Conn == nil {
			continue // closed: skip this beat
		}
		go func() {
			stamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
			client.SendMessage(PongProcessName, nil, stamp, onHeartbeat)
		}()
	}
}

// openConnTimeoutCheck runs the pending-request timeout sweep and blocks for the lifetime
// of its ticker, so it is started on its own goroutine. The ticker interval is
// requestTimeout milliseconds, or RequestDefaultTimeout when requestTimeout is unset or
// non-positive.
//
// On each tick, every registered callback older than the current timeout is removed.
// Its context is then completed asynchronously with RequestTimeoutCode/RequestTimeoutMsg.
//
// NOTE(review): requestCallback is also written by sendRequest and handlerResponse on
// other goroutines without a lock. Guarding the map with a mutex would remove that race.
func (client *Client) openConnTimeoutCheck() {
	if client.requestTimeout <= 0 {
		client.requestTimeout = RequestDefaultTimeout
	}
	if client.requestTimeoutTicker != nil {
		client.requestTimeoutTicker.Stop()
	}
	client.requestTimeoutTicker = time.NewTicker(time.Millisecond * time.Duration(client.requestTimeout))
	for range client.requestTimeoutTicker.C {
		// Re-read on each tick: SetRequestTimeout may have changed the value, and a
		// non-positive one would otherwise expire every pending request immediately.
		timeout := client.requestTimeout
		if timeout <= 0 {
			timeout = RequestDefaultTimeout
		}
		deadline := time.Now().UnixMilli() - timeout

		for key, pending := range client.requestCallback {
			if pending.CreateTime >= deadline {
				continue
			}
			delete(client.requestCallback, key) // deleting while ranging is allowed
			// Pass the entry as an argument. Before Go 1.22 a closure would share the
			// loop variable and could complete the wrong request.
			go func(rc *ClientRequestCallback) {
				rc.Context.Response.Code = RequestTimeoutCode
				rc.Context.Response.Message = RequestTimeoutMsg
				rc.Callback(rc.Context)
			}(pending)
		}
	}
}

// =====================================================================================================================

// SetRequestCommonHeader installs headers that sendRequest merges into every outgoing
// request; headers set on an individual request take precedence. The map is stored by
// reference, not copied.
func (client *Client) SetRequestCommonHeader(headers map[string]string) {
	client.requestCommonHeader = headers
}

// SetResponseCommonHeader installs headers that handlerRequest copies into every
// response it sends back. The map is stored by reference, not copied.
func (client *Client) SetResponseCommonHeader(headers map[string]string) {
	client.responseCommonHeader = headers
}

// SetRequestTimeout sets how long, in milliseconds, a sent request may wait for its
// response before the timeout sweep completes it with RequestTimeoutCode.
// NOTE: the sweep's tick interval is fixed when the sweep starts, so a later change only
// alters the expiry threshold, not how often expiry is checked.
func (client *Client) SetRequestTimeout(millis int64) {
	client.requestTimeout = millis
}

// SetPongTime sets the heartbeat interval in milliseconds.
// NOTE: openPongTask reads this value only when it starts, so a call made after the
// heartbeat task is running does not change its interval.
func (client *Client) SetPongTime(intervalMillis int64) {
	client.pongTaskTime = intervalMillis
}

// ShowPongLogs enables or disables logging of each heartbeat reply.
func (client *Client) ShowPongLogs(enabled bool) {
	client.isShowPongLogs = enabled
}

// IsActiveStatus reports whether the client still owns a connection. Conn is cleared
// only by CloseConnection, so a client whose link dropped is still reported as active
// while it reconnects.
func (client *Client) IsActiveStatus() bool {
	closed := client.Conn == nil
	return !closed
}

// AddClientRequestFilterFunc appends a filter that sendRequest runs before writing a
// request. Returning true stops the request from being sent. A nil filter is ignored,
// because the filter loop would panic calling it.
func (client *Client) AddClientRequestFilterFunc(filter func(context *ClientContext) bool) {
	if filter == nil {
		return
	}
	client.clientRequestFilters.PushBack(filter)
}

// AddClientResponseFilterFunc appends a filter that handlerResponse runs after a
// request's callback has been invoked. Returning true skips the remaining filters. A nil
// filter is ignored, because the filter loop would panic calling it.
func (client *Client) AddClientResponseFilterFunc(filter func(context *ClientContext) bool) {
	if filter == nil {
		return
	}
	client.clientResponseFilters.PushBack(filter)
}

// AddServerRequestFilterFunc appends a filter that handlerRequest runs on a request from
// the peer before dispatching it. Returning true drops the request without a reply. A
// nil filter is ignored, because the filter loop would panic calling it.
func (client *Client) AddServerRequestFilterFunc(filter func(context *ClientContext) bool) {
	if filter == nil {
		return
	}
	client.serverRequestFilters.PushBack(filter)
}

// AddServerResponseFilterFunc appends a filter that handlerRequest runs after the process
// handler and before the reply is sent. Returning true suppresses the reply. A nil filter
// is ignored, because the filter loop would panic calling it.
func (client *Client) AddServerResponseFilterFunc(filter func(context *ClientContext) bool) {
	if filter == nil {
		return
	}
	client.serverResponseFilters.PushBack(filter)
}

// RegisterProcessFunc binds processFunc to processFuncName. handlerRequest invokes it for
// incoming requests whose Process equals that name; all business logic belongs in such
// handlers. Registering a name again replaces the previous handler. A nil handler makes
// handlerRequest fall back to the default process function.
func (client *Client) RegisterProcessFunc(processFuncName string, processFunc func(context *ClientContext)) {
	handlers := client.processFuncMap
	handlers[processFuncName] = processFunc
}

// SetWriteBeforeFunc sets the transform applied to every encoded message just before it
// is written (e.g. compression). Passing nil restores the pass-through default instead
// of storing a nil function that would panic on the next write.
func (client *Client) SetWriteBeforeFunc(fun func(data []byte) []byte) {
	if fun == nil {
		fun = func(data []byte) []byte { return data }
	}
	client.writeBeforeFunc = fun
}

// SetReadBeforeFunc sets the transform applied to every received payload before it is
// decoded (e.g. decompression). Passing nil restores the pass-through default instead
// of storing a nil function that would panic on the next read.
func (client *Client) SetReadBeforeFunc(fun func(data []byte) []byte) {
	if fun == nil {
		fun = func(data []byte) []byte { return data }
	}
	client.readBeforeFunc = fun
}

// SetDefaultProcessFunc sets the handler used when an incoming request names a process
// with no registered handler. A nil fun is ignored and the current handler is kept, so
// dispatch never calls a nil function.
func (client *Client) SetDefaultProcessFunc(fun func(context *ClientContext)) {
	if fun == nil {
		return
	}
	client.defaultProcessFunc = fun
}

// CloseConnection closes the connection, logs the result of Close and stops the heartbeat
// and timeout tickers. Calling it on an already-closed client does nothing.
func (client *Client) CloseConnection() {
	conn := client.Conn
	if conn == nil {
		return
	}
	// Clear Conn before closing. The read loop and reconnect logic then see a closed
	// client when the close surfaces as a read error, instead of a live one to revive.
	client.Conn = nil
	log.Printf("CloseConnChannel: error(%v)", conn.Close())
	// Both tickers are created by goroutines started in ClientNew and may not exist yet.
	if client.requestTimeoutTicker != nil {
		client.requestTimeoutTicker.Stop()
	}
	if client.pongTaskTicker != nil {
		client.pongTaskTicker.Stop()
	}
}

// SendMessage sends a request for process with the given header and body under a fresh
// message id. callback may be nil. When set, it runs with the context once the response
// arrives or the request times out. The caller's header map is not modified:
// sendRequest replaces it with a merged copy.
func (client *Client) SendMessage(process string, header map[string]string, body string, callback func(c *ClientContext)) {
	request := new(SocketRequest)
	request.Uid = NewMessageId()
	request.Process = process
	request.Header = header
	request.Body = body
	// Only substitute an empty map when the caller passed none. The previous `!= nil`
	// check discarded every caller-supplied header.
	if request.Header == nil {
		request.Header = make(map[string]string)
	}
	client.sendRequest(request, callback)
}

// SendMessageAsync sends a request and, despite its name, blocks until its callback fires
// (response or timeout). It returns the resulting context and the context's Error.
//
// The result channel holds one value and is never closed. A response and a timeout can
// race and invoke the callback twice. With the old unbuffered-then-closed channel, the
// second delivery would block and then panic on the closed channel. Here the duplicate is
// either dropped or left in the buffer for the garbage collector.
// NOTE(review): if the timeout sweep has been stopped (CloseConnection) and no response
// arrives, this call never returns.
func (client *Client) SendMessageAsync(process string, header map[string]string, body string) (*ClientContext, error) {
	await := make(chan *ClientContext, 1)
	client.SendMessage(process, header, body, func(ctx *ClientContext) {
		select {
		case await <- ctx:
		default: // a result is already waiting; drop the duplicate
		}
	})
	context := <-await
	return context, context.Error
}

// =====================================================================================================================

// ClientNew dials url and returns a ready-to-use Client named clientName, with default
// heartbeat and request-timeout settings, empty filter chains, pass-through read/write
// transforms and a default process handler that logs unknown processes. On success it
// starts three goroutines: the heartbeat task, the request-timeout sweep and the read
// loop. If the dial fails, the dial error is returned and nothing is started.
func ClientNew(url, clientName string) (*Client, error) {
	passthrough := func(data []byte) []byte { return data }

	client := &Client{
		Name:                  clientName,
		Url:                   url,
		pongTaskTime:          PongProcessTime,
		requestTimeout:        RequestDefaultTimeout,
		clientRequestFilters:  list.New(),
		clientResponseFilters: list.New(),
		serverRequestFilters:  list.New(),
		serverResponseFilters: list.New(),
		processFuncMap:        make(map[string]func(context *ClientContext)),
		requestCallback:       make(map[string]*ClientRequestCallback),
		requestCommonHeader:   make(map[string]string),
		writeBeforeFunc:       passthrough,
		readBeforeFunc:        passthrough,
		defaultProcessFunc: func(context *ClientContext) {
			log.Printf("找不到对应的 Process %v", context.Request.Process)
		},
	}

	conn, err := client.openConnection(url)
	if err != nil {
		return nil, err
	}
	client.Conn = conn

	go client.openPongTask()
	go client.openConnTimeoutCheck()
	go client.readMessage()

	return client, nil
}
